Skip to content

optimized fused_grouped_topk SYCL kernel for MoE expert routing#253

Open
xiaolong-intel wants to merge 15 commits into
vllm-project:mainfrom
xiaolong-intel:grouped_topk
Open

optimized fused_grouped_topk SYCL kernel for MoE expert routing#253
xiaolong-intel wants to merge 15 commits into
vllm-project:mainfrom
xiaolong-intel:grouped_topk

Conversation

@xiaolong-intel
Copy link
Copy Markdown

@xiaolong-intel xiaolong-intel commented Apr 7, 2026

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.

Purpose

optimized fused_grouped_topk SYCL kernel for MoE expert routing

Test Plan

I wrote test cases in https://github.com/xiaolong-intel/vllm-xpu-kernels/blob/grouped_topk/tests/test_grouped_topk.py. Tested the consistency of the forward_xpu operator with the torch version of grouped_topk on B60

python -m pytest test_grouped_topk.py -v

Test Result

test cases:
image
test results:
image
All test cases passed successfully

Tested operator performance on GPU B60 with the following configuration:

  • num_expert_group = 8
  • topk_group = 4
  • topk = 8
image

(Optional) Documentation Update

BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing (anything written below this line will be removed by GitHub Actions)

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the XPU MoE “grouped topk” path by introducing a substantially revised SYCL implementation for fused_grouped_topk, updating its Torch bindings, and adjusting test coverage to validate correctness against a PyTorch baseline.

Changes:

  • Add _moe_C import to the Python fused MoE interface to ensure MoE ops are registered.
  • Rework csrc/moe/fused_grouped_topk.cpp with a new SYCL kernel implementation and update the Torch op schema formatting.
  • Update grouped-topk tests and baselines (param ranges, determinism settings, and wrapper behavior).

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
vllm_xpu_kernels/fused_moe_interface.py Ensures _moe_C is imported so MoE ops (including grouped topk) are available.
tests/test_grouped_topk.py Expands parameter coverage (notably token counts and expert counts) for grouped topk validation.
tests/ops/grouped_topk_op.py Adjusts the reference implementation for determinism and updates the SYCL wrapper behavior.
csrc/moe/torch_bindings.cpp Reformats/declares the fused_grouped_topk op schema.
csrc/moe/moe_ops.h Minor header formatting around the fused_grouped_topk declaration.
csrc/moe/fused_grouped_topk.cpp Replaces the previous kernel with a new SYCL implementation and host-side dispatch/validation logic.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread tests/test_grouped_topk.py Outdated
Comment on lines +21 to +24
@pytest.mark.parametrize("n_token", [64, 50000,100000])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("n_expert", [128, 256])
@pytest.mark.parametrize("topk", [8])
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default parametrization makes this test extremely large (n_token up to 100000, plus multiple values for n_hidden/n_expert/dtype/etc.), which is likely to cause very long runtimes and/or XPU OOM in the normal (non-MINI) test run. Consider keeping unit-test sizes small by default and gating large shapes behind the existing MINI_PYTEST_PARAMS/CI env, or moving the large-shape coverage to a dedicated benchmark/perf test.

Copilot uses AI. Check for mistakes.
Comment thread tests/ops/grouped_topk_op.py Outdated
Comment thread csrc/moe/fused_grouped_topk.cpp Outdated
Comment on lines +526 to +535
auto input_size = gating_output.sizes();
int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1];
int64_t n_group = n_expert_group;
int64_t topk_group = n_topk_group;
int64_t topk = n_topk;

TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0],
"Number of tokens mismatch");
TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor");
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_tokens = input_size[0] / num_experts = input_size[1] is computed before verifying gating_output is 2D, which can read out of bounds for unexpected inputs. Also, the previous implementation validated bias shape matched num_experts; that check is now missing, but the kernel indexes routingBias[...] assuming length == num_experts. Please add TORCH_CHECKs to validate gating_output.dim()==2 before indexing sizes, and if bias is provided ensure it is 1D (or broadcastable in an explicitly supported way) with bias.numel()==num_experts (and ideally on the same device).

Suggested change
auto input_size = gating_output.sizes();
int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1];
int64_t n_group = n_expert_group;
int64_t topk_group = n_topk_group;
int64_t topk = n_topk;
TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0],
"Number of tokens mismatch");
TORCH_CHECK(input_size.size() == 2, "gating_output must be a 2D Tensor");
TORCH_CHECK(gating_output.dim() == 2, "gating_output must be a 2D Tensor");
auto input_size = gating_output.sizes();
int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1];
int64_t n_group = n_expert_group;
int64_t topk_group = n_topk_group;
int64_t topk = n_topk;
TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0],
"Number of tokens mismatch");
if (has_bias) {
TORCH_CHECK(bias->dim() == 1, "bias must be a 1D Tensor");
TORCH_CHECK(bias->numel() == num_experts,
"bias must have num_experts elements");
TORCH_CHECK(bias->device() == gating_output.device(),
"bias must be on the same device as gating_output");
}

Copilot uses AI. Check for mistakes.
Comment thread csrc/moe/fused_grouped_topk.cpp Outdated

constexpr int MaxExpertCandidatesPerLane = NumDeepseekExperts / WARP_SIZE;
T localCandidateScores[MaxExpertCandidatesPerLane];
IdxT localCandidateIdx[MaxExpertCandidatesPerLane];
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the grouped-routing path, selectedExpertScores/Idx are sized to DefaultMaxNumTopExperts (8), but topk is only checked as <= 32. If topk > 8 with UseGroups=true, reduceTopK writes past these arrays. Either enforce topk <= DefaultMaxNumTopExperts for the grouped path, or add additional kernel instantiations / dynamic storage sized for larger topk.

Suggested change
IdxT localCandidateIdx[MaxExpertCandidatesPerLane];
IdxT localCandidateIdx[MaxExpertCandidatesPerLane];
TORCH_CHECK(
topk <= DefaultMaxNumTopExperts,
"Grouped routing only supports topk <= ", DefaultMaxNumTopExperts,
", but got topk=", topk,
". Increase grouped-path storage/kernel support for larger topk values.");

Copilot uses AI. Check for mistakes.
Comment thread csrc/moe/fused_grouped_topk.cpp Outdated
Comment on lines +247 to +267
int32_t totalCandidates = topkGroup * numExpertsPerGroup;
for (int32_t candidate = laneIdx; candidate < totalCandidates;
candidate += WARP_SIZE) {
int32_t localSlot = candidate / WARP_SIZE;
int32_t selectedGroup = candidate / numExpertsPerGroup;
int32_t expertInGroup = candidate % numExpertsPerGroup;
int32_t gid = selectedGroupIdx[selectedGroup];
int32_t idx = gid * numExpertsPerGroup + expertInGroup;
T candidateScore = neg_inf<T>();

T input = scoresToken[idx];
if (is_finite(input)) {
T score = apply_scoring<SF>(input);
candidateScore = score;
if (has_bias) {
candidateScore = candidateScore + sycl_cast<T, BiasT>(routingBias[idx]);
}
}

localCandidateScores[localSlot] = candidateScore;
localCandidateIdx[localSlot] = static_cast<IdxT>(idx);
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

localCandidateScores/localCandidateIdx are sized assuming at most NumDeepseekExperts candidates (256 total, 8 per lane), but totalCandidates = topkGroup * numExpertsPerGroup can exceed 256 for valid inputs (e.g., small numGroup with larger numExpertsPerGroup). When that happens, localSlot = candidate / WARP_SIZE will exceed MaxExpertCandidatesPerLane and write out of bounds. Add a host-side TORCH_CHECK that topkGroup * numExpertsPerGroup <= 256 (or dispatch a kernel variant sized for the actual maximum candidates).

Copilot uses AI. Check for mistakes.
@jikunshang
Copy link
Copy Markdown
Member

what's benefit come from? how do you benchmark?

@xiaolong-intel
Copy link
Copy Markdown
Author

what's benefit come from? how do you benchmark?

Benifit comes from:

  • The old kernel first keeps all experts around, masks out the unwanted groups, and then runs top-k on that larger set. The new kernel first picks the best groups, then only looks at experts inside those groups, and finally does top-k on this much smaller candidate set.
  • In old kernel,a group often contains more than 32 experts, so one warp cannot directly keep all group values in registers and do the full selection in one step. So each lane needs to introduce extra loops.

I used https://github.com/xiaolong-intel/vllm-xpu-kernels/blob/grouped_topk/tests/test_grouped_topk.py. to test the precision and accuracy of the op.
Using torch.profile, tested the kernel_time of two ops under the same input and precision.Calculate ten times and take the average after warmup.

def _run_profile(num_tokens: int, hidden_dim: int, num_experts: int):
    if not torch.xpu.is_available():
        raise RuntimeError("XPU is not available for profiling")
    
    device = DEVICE
    print("Allocating profiling tensors...")

    test_hidden = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
    test_gating = torch.randn(num_tokens, num_experts, device=device, dtype=torch.bfloat16)
    test_bias = torch.zeros(num_experts, device=device, dtype=torch.bfloat16)

    print("Profiling Fused Kernel...")
    # Increase warmup and active iterations for better stability
    for _ in range(50):
        _run_op_once(num_tokens, num_experts, test_hidden, test_gating, test_bias)
    torch.xpu.synchronize()

    activities = [torch.profiler.ProfilerActivity.CPU]
    if hasattr(torch.profiler.ProfilerActivity, "XPU"):
        activities.append(torch.profiler.ProfilerActivity.XPU)

    default_sort = (
        "self_xpu_time_total"
        if hasattr(torch.profiler.ProfilerActivity, "XPU")
        else "self_cpu_time_total"
    )
    sort_by = PROFILE_SORT or default_sort
    if PROFILE_TARGET in ("both", "fused"):
        _profile_kernel(
            "fused",
            lambda: _run_op_once(
                num_tokens, hidden_dim, test_hidden, test_gating, test_bias),
            PROFILE_WARMUP,
            PROFILE_ACTIVE,
            PROFILE_REPEAT,
            activities,
            sort_by,
            PROFILE_TRACE_DIR,
        )

    torch.xpu.synchronize()
    print("Profiling Grouped Kernel...")
    if PROFILE_TARGET in ("both", "grouped"):
        _profile_kernel(
            "grouped",
            lambda: _run_grouped_topk_once(
                num_tokens, test_hidden, test_gating, test_bias),
            PROFILE_WARMUP,
            PROFILE_ACTIVE,
            PROFILE_REPEAT,
            activities,
            sort_by,
            PROFILE_TRACE_DIR,
        )

Comment thread tests/ops/grouped_topk_op.py
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you compared the performance of the softmax scoring_func?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. I tested the kernel_time under softmax scoring_func with seqlenth=100000. The above is the original version, and the below is the optimized version.
I am currently testing with https://github.com/vllm-project/vllm-xpu-kernels/blob/main/benchmark/benchmark_grouped_topk.py , and I will provide the results later.
image

Comment thread vllm_xpu_kernels/fused_moe_interface.py Outdated
try:
from . import _C # noqa: F401
from . import _xpu_C # noqa: F401
from . import _moe_C # noqa: F401
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do not add unnecessary import.

@mayuyuace
Copy link
Copy Markdown
Collaborator

Please provide benchmark results as benchmark_grouped_topk do.
https://github.com/vllm-project/vllm-xpu-kernels/blob/main/benchmark/benchmark_grouped_topk.py

Comment thread csrc/moe/fused_grouped_topk.cpp Outdated
#include <sycl/sycl.hpp>

#include "../utils.h"
#include <c10/xpu/XPUStream.h>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please don't include here.
../utils.h should cover this.

Comment thread csrc/moe/fused_grouped_topk.cpp Outdated
return 1.0f / (1.0f + sycl::native::exp(-x));
}
// Type trait: bfloat16 -> float for computation, everything else stays as-is
template <typename T>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for such util, move to utils.h

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, thank you for your suggestion, I will make the correction.

@xiaolong-intel
Copy link
Copy Markdown
Author

xiaolong-intel commented Apr 8, 2026

Please provide benchmark results as benchmark_grouped_topk do. https://github.com/vllm-project/vllm-xpu-kernels/blob/main/benchmark/benchmark_grouped_topk.py

Hi mayuyuace:
I used https://github.com/vllm-project/vllm-xpu-kernels/blob/main/benchmark/benchmark_grouped_topk.py to test operator performance on the B60, with seq_length using a large length of 50000.
I conducted tests under the DeepSeek conditions [256,8,4,8] and the Nemo conditions [128,1,1,6].
Benchmark.py is a useful tool, thanks for your suggestion. : )
image

@xiaolong-intel xiaolong-intel changed the title Grouped topk optimized fused_grouped_topk SYCL kernel for MoE expert routin Apr 8, 2026
@xiaolong-intel xiaolong-intel changed the title optimized fused_grouped_topk SYCL kernel for MoE expert routin optimized fused_grouped_topk SYCL kernel for MoE expert routing Apr 8, 2026
@xiaolong-intel xiaolong-intel marked this pull request as draft April 8, 2026 07:08
@xiaolong-intel xiaolong-intel marked this pull request as ready for review April 9, 2026 02:47
Copy link
Copy Markdown
Member

@jikunshang jikunshang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls fix pre-commit issue and do not change oneDNN commit.

Comment thread benchmark/benchmark_grouped_topk.py Outdated
n_expert_range = [16, 64, 128]
topk_range = [2, 4]
topk_group_range = [4, 8]
n_token_range = [50000]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please don't remove previous one.


benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=args.save_path)
benchmark.run(print_data=True, save_path=args.save_path) No newline at end of file
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls add blank line

Comment thread csrc/utils.h Outdated
Comment on lines +95 to +96
std::is_same_v<T, sycl::ext::oneapi::bfloat16> ||
std::is_same_v<T, sycl::half>;;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

necessary?

Comment thread tests/test_grouped_topk.py Outdated


@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_token", [64, 50000,100000])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't want use such large shape in CI which may make job longer to complete.

Comment thread tests/test_grouped_topk.py Outdated
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("topk_group", [4])
@pytest.mark.parametrize("scoring_func", ["sigmoid","softmax"])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please revert unnecessary change. same for blank line

@xiaolong-intel xiaolong-intel marked this pull request as draft April 10, 2026 05:25
@xiaolong-intel xiaolong-intel marked this pull request as ready for review April 10, 2026 05:57
@xiaolong-intel
Copy link
Copy Markdown
Author

oneDNN commit.

pre-commit is completed. onednn. It seems I haven't made any changes😂, just synced with the latest vllm-xpu-kernels.

Comment thread csrc/moe/moe_ops.h
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this for?

Comment thread third_party/oneDNN
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why update onednn?

@mayuyuace
Copy link
Copy Markdown
Collaborator

You may need rebase the branch and only keep the key code changing.
Onednn should not be changed in this PR.

@xiaolong-intel
Copy link
Copy Markdown
Author

You may need rebase the branch and only keep the key code changing. Onednn should not be changed in this PR.

Okay, I understand, thank you

@xiaolong-intel xiaolong-intel marked this pull request as draft April 22, 2026 02:40
@xiaolong-intel xiaolong-intel marked this pull request as ready for review April 22, 2026 03:11
@xinyu-intel
Copy link
Copy Markdown
Collaborator

can you rebase and fix DCO?

@xiaolong-intel xiaolong-intel force-pushed the grouped_topk branch 3 times, most recently from 91fe795 to b6e2254 Compare April 29, 2026 06:39
@xiaolong-intel
Copy link
Copy Markdown
Author

can you rebase and fix DCO?

Done.Thanks

@xiaolong-intel
Copy link
Copy Markdown
Author

Hello everyone:

To validate the kernel's effectiveness in real-world serving scenarios, we benchmarked using NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 on 4x Intel B60 GPUs with vLLM (tensor-parallel-size=4). This model has 52 decoder layers, of which 23 are MoE layers with the following configuration:

  • num_experts = 128
  • n_group = 1
  • topk_group = 1
  • topk = 8
  • scoring_func = sigmoid

Benchmark Setup

All benchmarks use --max-concurrency 2 and --request-rate inf to simulate realistic serving conditions.

Results

TTFT Benchmark (prefill-dominant)

  • Input length: 8192, Output length: 64, 100 prompts
Metric Baseline Optimized Improvement
Mean TTFT (ms) 2340.88 2109.38 -9.9%
Median TTFT (ms) 2339.16 2110.56 -9.8%
P99 TTFT (ms) 2379.89 2116.13 -11.1%

Decode Throughput Benchmark (decode-dominant)

  • Input length: 128, Output length: 512, 50 prompts
Metric Baseline Optimized Improvement
Median TPOT (ms) 49.17 47.16 -4.1%
Output throughput (tok/s) 40.34 41.32 +2.4%
Total throughput (tok/s) 50.43 51.65 +2.4%

Summary

The optimized kernel achieves:

  • ~10% TTFT reduction on long-context prefill (8K tokens)
  • ~2-4% decode throughput improvement

original results are as follows:
image

BUG Report

In addition, I also discovered a bug in the current kernel version.
When expert >= 256, for example kimi expert = 384, at this time the kernel will have an array out-of-bounds access problem

for (int e = 0; e < calc_per_item; ++e) {
      load_elems[e] = kNegInfinity;
      local_idx[e] = -1;
      bias[e] = 0.0f;  // Initialize bias to zero
    }

In this loop, the number of iterations is calc_per_item,int calc_per_item = (experts + sub_group_size - 1) / sub_group_size;,with the default subgroup=32, and calc_per_item is 12 when expert=384

T load_elems[malloc_per_item];
    int local_idx[malloc_per_item];
    T bias[malloc_per_item];

But by definition, the size of load_elems is malloc_per_item. static constexpr int malloc_per_item = MAX_EXPERT_GROUPS;

image When expert_group=1, max_expert_group=8, so malloc_per_item is also 8. So the code will access load_elems[9], [10], [11], thus causing an out-of-bounds error

@xiaolong-intel xiaolong-intel marked this pull request as draft May 26, 2026 06:11
@xiaolong-intel xiaolong-intel marked this pull request as ready for review May 26, 2026 06:36
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong-intel <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong-intel <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong-intel <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong-intel <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
…onversion of renormalization

Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>

Signed-off-by:  <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: xiaolong <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
Signed-off-by: root <xiaolong.guo@intel.com>
@xinyu-intel
Copy link
Copy Markdown
Collaborator

Hello everyone:

To validate the kernel's effectiveness in real-world serving scenarios, we benchmarked using NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 on 4x Intel B60 GPUs with vLLM (tensor-parallel-size=4). This model has 52 decoder layers, of which 23 are MoE layers with the following configuration:

  • num_experts = 128
  • n_group = 1
  • topk_group = 1
  • topk = 8
  • scoring_func = sigmoid

Benchmark Setup

All benchmarks use --max-concurrency 2 and --request-rate inf to simulate realistic serving conditions.

Results

TTFT Benchmark (prefill-dominant)

  • Input length: 8192, Output length: 64, 100 prompts

Metric Baseline Optimized Improvement
Mean TTFT (ms) 2340.88 2109.38 -9.9%
Median TTFT (ms) 2339.16 2110.56 -9.8%
P99 TTFT (ms) 2379.89 2116.13 -11.1%

Decode Throughput Benchmark (decode-dominant)

  • Input length: 128, Output length: 512, 50 prompts

Metric Baseline Optimized Improvement
Median TPOT (ms) 49.17 47.16 -4.1%
Output throughput (tok/s) 40.34 41.32 +2.4%
Total throughput (tok/s) 50.43 51.65 +2.4%

Summary

The optimized kernel achieves:

  • ~10% TTFT reduction on long-context prefill (8K tokens)
  • ~2-4% decode throughput improvement

original results are as follows: image

BUG Report

In addition, I also discovered a bug in the current kernel version. When expert >= 256, for example kimi expert = 384, at this time the kernel will have an array out-of-bounds access problem

for (int e = 0; e < calc_per_item; ++e) {
      load_elems[e] = kNegInfinity;
      local_idx[e] = -1;
      bias[e] = 0.0f;  // Initialize bias to zero
    }

In this loop, the number of iterations is calc_per_item,int calc_per_item = (experts + sub_group_size - 1) / sub_group_size;,with the default subgroup=32, and calc_per_item is 12 when expert=384

T load_elems[malloc_per_item];
    int local_idx[malloc_per_item];
    T bias[malloc_per_item];

But by definition, the size of load_elems is malloc_per_item. static constexpr int malloc_per_item = MAX_EXPERT_GROUPS;

image When expert_group=1, max_expert_group=8, so malloc_per_item is also 8. So the code will access load_elems[9], [10], [11], thus causing an out-of-bounds error

can you also update the single kernel performance improvement?

Copy link
Copy Markdown
Collaborator

@xinyu-intel xinyu-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the contribution. I find free function kernel has been introduced via this PR, but we are still using cgh.parallel_for for kernel submission. We can consider to move existing kernels to free function kernel as well as using sycl::nd_launch for the submission after 2026.0 upgrading. @jikunshang

Comment thread tests/test_grouped_topk.py
@xiaolong-intel
Copy link
Copy Markdown
Author

xiaolong-intel commented May 27, 2026

Hello everyone:
To validate the kernel's effectiveness in real-world serving scenarios, we benchmarked using NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 on 4x Intel B60 GPUs with vLLM (tensor-parallel-size=4). This model has 52 decoder layers, of which 23 are MoE layers with the following configuration:

  • num_experts = 128
  • n_group = 1
  • topk_group = 1
  • topk = 8
  • scoring_func = sigmoid

Benchmark Setup

All benchmarks use --max-concurrency 2 and --request-rate inf to simulate realistic serving conditions.

Results

TTFT Benchmark (prefill-dominant)

  • Input length: 8192, Output length: 64, 100 prompts

Metric Baseline Optimized Improvement
Mean TTFT (ms) 2340.88 2109.38 -9.9%
Median TTFT (ms) 2339.16 2110.56 -9.8%
P99 TTFT (ms) 2379.89 2116.13 -11.1%

Decode Throughput Benchmark (decode-dominant)

  • Input length: 128, Output length: 512, 50 prompts

Metric Baseline Optimized Improvement
Median TPOT (ms) 49.17 47.16 -4.1%
Output throughput (tok/s) 40.34 41.32 +2.4%
Total throughput (tok/s) 50.43 51.65 +2.4%

Summary

The optimized kernel achieves:

  • ~10% TTFT reduction on long-context prefill (8K tokens)
  • ~2-4% decode throughput improvement

original results are as follows: image

BUG Report

In addition, I also discovered a bug in the current kernel version. When expert >= 256, for example kimi expert = 384, at this time the kernel will have an array out-of-bounds access problem

for (int e = 0; e < calc_per_item; ++e) {
      load_elems[e] = kNegInfinity;
      local_idx[e] = -1;
      bias[e] = 0.0f;  // Initialize bias to zero
    }

In this loop, the number of iterations is calc_per_item,int calc_per_item = (experts + sub_group_size - 1) / sub_group_size;,with the default subgroup=32, and calc_per_item is 12 when expert=384

T load_elems[malloc_per_item];
    int local_idx[malloc_per_item];
    T bias[malloc_per_item];

But by definition, the size of load_elems is malloc_per_item. static constexpr int malloc_per_item = MAX_EXPERT_GROUPS;
image When expert_group=1, max_expert_group=8, so malloc_per_item is also 8. So the code will access load_elems[9], [10], [11], thus causing an out-of-bounds error

can you also update the single kernel performance improvement?

Okay. I captured the trace during nemotron inference and can see that the kernel time dropped from 76us to 17us.
image
image
In addition, I also tested the performance of the single kernel under different combinations(compared with original fused_grouped_topk). Results are as follows:

GLM condition
expert=256 group=1 topk_group=1 topk=8
Case small=256 medium=8192 large=50000
[case=small] topk=8 group=1/1 fused_grouped_topk_original_avg_us=170.335 fused_grouped_topk_optimize_avg_us=132.608 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=1.285x
[case=medium] topk=8 group=1/1 fused_grouped_topk_original_avg_us=634.793 fused_grouped_topk_optimize_avg_us=259.923 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=2.442x
[case=large] topk=8 group=1/1 fused_grouped_topk_original_avg_us=3245.944 fused_grouped_topk_optimize_avg_us=1206.317 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=2.691x
 
 
DS condition
expert=256 group=8 topk_group=4 topk=8
Case small=256 medium=8192 large=50000
[case=small] topk=8 group=8/4 fused_grouped_topk_original_avg_us=182.549 fused_grouped_topk_optimize_avg_us=132.161 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=1.381x
[case=medium] topk=8 group=8/4 fused_grouped_topk_original_avg_us=665.981 fused_grouped_topk_optimize_avg_us=247.630 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=2.689x
[case=large] topk=8 group=8/4 fused_grouped_topk_original_avg_us=3439.280 fused_grouped_topk_optimize_avg_us=1060.260 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=3.244x
 
Nemotron condition
expert=128 group=1 topk_group=1 topk=6
Case small=256 medium=8192 large=50000
[case=small] topk=6 group=1/1 fused_grouped_topk_original_avg_us=167.064 fused_grouped_topk_optimize_avg_us=140.846 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=1.186x
[case=medium] topk=6 group=1/1 fused_grouped_topk_original_avg_us=524.293 fused_grouped_topk_optimize_avg_us=224.320 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=2.337x
[case=large] topk=6 group=1/1 fused_grouped_topk_original_avg_us=2636.241 fused_grouped_topk_optimize_avg_us=997.623 speedup(fused_grouped_topk_original/fused_grouped_topk_optimize)=2.643x

@xiaolong-intel xiaolong-intel requested a review from jikunshang May 28, 2026 02:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants